import os
import argparse
import torch
from tqdm import tqdm
from utils import ParamDiffAug
from utils import setup_seed, DiffAugment, test_accuracy, get_preparation
from torchvision import transforms


def args_parser():
    """Parse and validate the command-line arguments for training one shadow model.

    Returns:
        argparse.Namespace: the parsed arguments.

    On an inconsistent option combination (e.g. ``--method Diffusion`` without
    ``--trainset_size``), ``parser.error`` prints the usage plus a message to
    stderr and exits with status 2.
    """
    parser = argparse.ArgumentParser(description='train a shadow model')
    parser.add_argument('--exp_id', type=int, required=True, help='the index of the shadow model')
    parser.add_argument('--num_shadow', type=int, default=None, help='total number of shadow models')
    parser.add_argument('--num_canaries', type=int, default=None, help='number of mislabeled canaries')
    parser.add_argument('--syn_data_path', type=str, default=None, help='the folder path of synthetic data')
    parser.add_argument('--data_path', type=str, default='data', help='the path to the cifar10 dataset')
    parser.add_argument('--lira_path', type=str, required=True, help='the folder path to save the LiRA results')
    parser.add_argument('--epochs', type=int, default=200, help="number of epochs to train the shadow models")
    parser.add_argument('--lr_net', type=float, default=0.1, help="learning rate for training")
    parser.add_argument('--model_type', type=str, choices=['ConvNet', 'ResNet18', 'ResNet18BN'], help='The model type to use')
    parser.add_argument('--method', type=str, choices=['cifar10', 'random', 'forgetting', 'DM', 'DSA', 'MTT', 'DATM', 'Diffusion'], help='method of coreset selection or generating synthetic data')
    parser.add_argument('--use_dd_aug', action='store_true', help='whether to use transforms in DD')
    parser.add_argument('--save_freq', type=int, default=None, help='frequency to save the model')
    # for cifar10, random, and forgetting
    parser.add_argument('--avg_case', action='store_true', default=False, help='use average case in-out split')
    # for diffusion only
    parser.add_argument('--trainset_size', type=int, default=None, help='number of synthetic data')
    # for coreset only
    parser.add_argument('--num_coreset', type=int, default=None, help='number of coreset samples')

    args = parser.parse_args()

    # Cross-option checks that argparse cannot express declaratively.
    # TODO: add more checks
    if args.method == 'Diffusion' and args.trainset_size is None:
        parser.error("--method 'Diffusion' requires --trainset_size to be specified")
    if args.method == 'cifar10' and args.syn_data_path is not None:
        parser.error(f"--syn_data_path is not expected to be specified when --method is {args.method}")
    if args.method == 'cifar10' and args.num_shadow is None:
        parser.error(f"--method {args.method} requires --num_shadow to be specified")
    if args.method == 'cifar10' and args.num_canaries is None:
        parser.error(f"--method {args.method} requires --num_canaries to be specified")
    if args.method in ('random', 'forgetting') and args.num_coreset is None:
        parser.error(f"--method {args.method} requires --num_coreset to be specified")
    # The training loop runs `epochs` epochs and checkpoints when
    # `epoch % save_freq == 0`. Zero epochs leaves no accuracy to report, and
    # save_freq == 0 would divide by zero, so both must be positive.
    if args.epochs < 1:
        parser.error("--epochs must be a positive integer")
    if args.save_freq is not None and args.save_freq < 1:
        parser.error("--save_freq must be a positive integer")

    return args


def main():
    """Train a single shadow model and checkpoint it under ``--lira_path``.

    Files written to ``<lira_path>/ckpts_<model_type>[_dd_aug]/`` (whole model
    objects via ``torch.save``):
      * ``model_best_<exp_id>.pt``: overwritten whenever test accuracy improves;
      * ``model_epoch_<epoch>_<exp_id>.pt``: every ``--save_freq`` epochs, if set;
      * ``model_last_<exp_id>.pt``: after the final epoch.
    """
    args = args_parser()
    args.device = 'cuda' if torch.cuda.is_available() else 'cpu'
    print(f"Using device: {args.device}")

    # create folders to save the LiRA results
    suffix = f'{args.model_type}_dd_aug' if args.use_dd_aug else f'{args.model_type}'
    ckpt_path = os.path.join(args.lira_path, f"ckpts_{suffix}")
    os.makedirs(ckpt_path, exist_ok=True)

    # setup random seeds
    # NOTE(review): the seed is the same for every exp_id, so shadow models can
    # differ only through what get_preparation derives from args — confirm intended.
    setup_seed(1)

    # Optional DD-style augmentation; applied to every training batch below.
    if args.use_dd_aug:
        dsa_param = ParamDiffAug()
        dsa_strategy = 'color_crop_cutout_flip_scale_rotate'

    # get configurations
    train_loader, test_loader, model, optimizer, criterion, scheduler = get_preparation(args)

    # start training
    # Start at -inf rather than 0 so model_best_<exp_id>.pt is always written
    # after the first epoch, even if test accuracy never rises above 0.
    best_acc = float('-inf')
    acc = None  # remains None only if no epoch runs
    for epoch in tqdm(range(args.epochs)):
        model.train()
        for inputs, labels in train_loader:
            inputs, labels = inputs.to(args.device), labels.to(args.device)
            if args.use_dd_aug:
                inputs = DiffAugment(inputs, dsa_strategy, param=dsa_param)
            optimizer.zero_grad()
            outputs = model(inputs)
            loss = criterion(outputs, labels)
            loss.backward()
            optimizer.step()
            # NOTE(review): the scheduler is stepped once per batch, not per
            # epoch; this assumes get_preparation sized it in iterations — confirm.
            scheduler.step()

        acc, _ = test_accuracy(model, test_loader)
        if acc > best_acc:
            best_acc = acc
            torch.save(model, os.path.join(ckpt_path, f"model_best_{args.exp_id}.pt"))

        # Periodic checkpoints at epochs 0, save_freq, 2*save_freq, ...
        # A falsy save_freq (None or 0) disables them instead of dividing by zero.
        if args.save_freq and epoch % args.save_freq == 0:
            torch.save(model, os.path.join(ckpt_path, f"model_epoch_{epoch}_{args.exp_id}.pt"))

    torch.save(model, os.path.join(ckpt_path, f"model_last_{args.exp_id}.pt"))
    print("Best accuracy: ", best_acc)
    print("Last accuracy: ", acc)


if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()





